library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

use work.core_pkg.all;
use work.op_pkg.all;

-- Decode stage: latches PC and instruction coming from fetch into a pipeline
-- register, drives the register file, and translates the latched instruction
-- into the control records for execute, memory and writeback.
entity decode is
	port (
		clk        : in  std_logic;
		res_n      : in  std_logic;  -- asynchronous, active-low reset of the pipeline register
		stall      : in  std_logic;  -- '1' holds the pipeline register (also forwarded to the regfile)
		flush      : in  std_logic;  -- '1' keeps exec/mem/wb outputs at their NOP defaults (combinational)

		-- from fetch
		pc_in      : in  pc_type;     -- latched into the pipeline register
		instr      : in  instr_type;  -- latched; its rs1/rs2 fields also address the regfile directly

		-- from writeback
		reg_write  : in reg_write_type;  -- regfile write port (fields write, reg, data)

		-- towards next stages
		pc_out     : out pc_type;       -- PC of the instruction currently being decoded
		exec_op    : out exec_op_type;
		mem_op     : out mem_op_type;
		wb_op      : out wb_op_type;

		-- exceptions
		exc_dec    : out std_logic     -- '1' for unrecognized opcodes or undefined funct encodings
	);
end entity;

architecture rtl of decode is
	-- Major opcodes, located in instr(6 downto 0). The literals are written as
	-- opcode[6:2] _ opcode[1:0]; every opcode decoded here ends in "11".
	subtype opcode_range_t is natural range 6 downto 0;
	constant OPC_LOAD   : std_logic_vector(opcode_range_t) := B"00000_11";
	constant OPC_STORE  : std_logic_vector(opcode_range_t) := B"01000_11";
	constant OPC_BRANCH : std_logic_vector(opcode_range_t) := B"11000_11";
	constant OPC_JALR   : std_logic_vector(opcode_range_t) := B"11001_11";
	constant OPC_JAL    : std_logic_vector(opcode_range_t) := B"11011_11";
	constant OPC_OP_IMM : std_logic_vector(opcode_range_t) := B"00100_11";
	constant OPC_OP     : std_logic_vector(opcode_range_t) := B"01100_11";
	constant OPC_AUIPC  : std_logic_vector(opcode_range_t) := B"00101_11";
	constant OPC_LUI    : std_logic_vector(opcode_range_t) := B"01101_11";
	-- RISC-V MISC-MEM (FENCE) opcode; the decoder treats it as a no-op
	constant OPC_NOP    : std_logic_vector(opcode_range_t) := B"00011_11";

	-- Pipeline register of this stage: the PC and instruction word taken over
	-- from fetch. The combinational decoder works exclusively on these.
	type internal_t is record
		pc    : pc_type;
		instr : instr_type;
	end record;

	-- Reset contents: PC zero and the instruction word 0x0000000F
	-- (opcode 0001111, funct3 "000"), which decodes as a no-op.
	constant INSTR_NOP        : instr_type := X"0000000F";
	constant INITIAL_INTERNAL : internal_t := (pc => (others => '0'), instr => INSTR_NOP);

	signal internal : internal_t;

	-- Bundle of all signals connected to the register file instance.
	type reg_t is record
		-- read side: two addresses in, two data words out
		rdaddr1  : reg_adr_type;
		rdaddr2  : reg_adr_type;
		rddata1  : data_type;
		rddata2  : data_type;
		-- write side: driven from the writeback stage
		wraddr   : reg_adr_type;
		wrdata   : data_type;
		regwrite : std_logic;
	end record;

	signal reg : reg_t;

	-- Field positions inside the 32-bit instruction word. funct3 occupies
	-- bits 14..12 in every format decoded here; funct7 (bits 31..25) is only
	-- consulted for R-format instructions.
	subtype funct3_R_fmt_range_t is natural range 14 downto 12;
	subtype funct3_I_fmt_range_t is natural range 14 downto 12;
	subtype funct3_S_fmt_range_t is natural range 14 downto 12;
	subtype funct3_B_fmt_range_t is natural range 14 downto 12;
	subtype funct7_R_fmt_range_t is natural range 31 downto 25;

	-- R-format (OP) selectors: funct3 followed by funct7 (3 + 7 = 10 bits),
	-- matched against instr(14 downto 12) & instr(31 downto 25).
	subtype r_funct_range_t is natural range 9 downto 0;
	constant FUNCT_ADD  : std_logic_vector(r_funct_range_t) := B"000_0000000";
	constant FUNCT_SUB  : std_logic_vector(r_funct_range_t) := B"000_0100000";
	constant FUNCT_SLL  : std_logic_vector(r_funct_range_t) := B"001_0000000";
	constant FUNCT_SLT  : std_logic_vector(r_funct_range_t) := B"010_0000000";
	constant FUNCT_SLTU : std_logic_vector(r_funct_range_t) := B"011_0000000";
	constant FUNCT_XOR  : std_logic_vector(r_funct_range_t) := B"100_0000000";
	constant FUNCT_SRL  : std_logic_vector(r_funct_range_t) := B"101_0000000";
	constant FUNCT_SRA  : std_logic_vector(r_funct_range_t) := B"101_0100000";
	constant FUNCT_OR   : std_logic_vector(r_funct_range_t) := B"110_0000000";
	constant FUNCT_AND  : std_logic_vector(r_funct_range_t) := B"111_0000000";

	-- I-format (OP-IMM) funct3 values. SRLI and SRAI share FUNCT_SRI and are
	-- distinguished by instr(30).
	subtype i_funct_range_t is natural range 2 downto 0;
	constant FUNCT_ADDI  : std_logic_vector(i_funct_range_t) := "000";
	constant FUNCT_SLLI  : std_logic_vector(i_funct_range_t) := "001";
	constant FUNCT_SLTI  : std_logic_vector(i_funct_range_t) := "010";
	constant FUNCT_SLTIU : std_logic_vector(i_funct_range_t) := "011";
	constant FUNCT_XORI  : std_logic_vector(i_funct_range_t) := "100";
	constant FUNCT_SRI   : std_logic_vector(i_funct_range_t) := "101";
	constant FUNCT_ORI   : std_logic_vector(i_funct_range_t) := "110";
	constant FUNCT_ANDI  : std_logic_vector(i_funct_range_t) := "111";

	-- STORE funct3 values: byte, half word, word.
	subtype s_funct_range_t is natural range 2 downto 0;
	constant FUNCT_SB : std_logic_vector(s_funct_range_t) := "000";
	constant FUNCT_SH : std_logic_vector(s_funct_range_t) := "001";
	constant FUNCT_SW : std_logic_vector(s_funct_range_t) := "010";

	-- LOAD funct3 values (sharing i_funct_range_t): byte, half word, word,
	-- unsigned byte, unsigned half word.
	constant FUNCT_LB  : std_logic_vector(i_funct_range_t) := "000";
	constant FUNCT_LH  : std_logic_vector(i_funct_range_t) := "001";
	constant FUNCT_LW  : std_logic_vector(i_funct_range_t) := "010";
	constant FUNCT_LBU : std_logic_vector(i_funct_range_t) := "100";
	constant FUNCT_LHU : std_logic_vector(i_funct_range_t) := "101";

	-- BRANCH funct3 values; "010" and "011" are left unassigned.
	subtype b_funct_range_t is natural range 2 downto 0;
	constant FUNCT_BEQ  : std_logic_vector(b_funct_range_t) := "000";
	constant FUNCT_BNE  : std_logic_vector(b_funct_range_t) := "001";
	constant FUNCT_BLT  : std_logic_vector(b_funct_range_t) := "100";
	constant FUNCT_BGE  : std_logic_vector(b_funct_range_t) := "101";
	constant FUNCT_BLTU : std_logic_vector(b_funct_range_t) := "110";
	constant FUNCT_BGEU : std_logic_vector(b_funct_range_t) := "111";

	-- Register specifier fields: sources rs1/rs2 and destination rd.
	subtype rs1_range_t is natural range 19 downto 15;
	subtype rs2_range_t is natural range 24 downto 20;
	subtype rd_range_t  is natural range 11 downto 7;

	-- immediate calculation helper functions
	-- I-format immediate: imm[11:0] = instr(31 downto 20), sign-extended from
	-- instr(31) to the full data word.
	function calculate_immediate_I_fmt(instr : instr_type) return data_type is
		variable result : data_type;
	begin
		-- start with every bit equal to the sign, then overlay the low 11 bits
		result := (others => instr(31));
		result(10 downto 0) := instr(30 downto 20);
		return result;
	end function;

	-- S-format immediate: imm[11:5] = instr(31 downto 25) and
	-- imm[4:0] = instr(11 downto 7), sign-extended from instr(31).
	function calculate_immediate_S_fmt(instr : instr_type) return data_type is
		variable result : data_type;
	begin
		result := (others => instr(31));
		result(10 downto 5) := instr(30 downto 25);
		result(4 downto 0) := instr(11 downto 7);
		return result;
	end function;

	-- B-format (branch offset) immediate: imm[12] = instr(31) (sign),
	-- imm[11] = instr(7), imm[10:5] = instr(30 downto 25),
	-- imm[4:1] = instr(11 downto 8); bit 0 is always '0'.
	function calculate_immediate_B_fmt(instr : instr_type) return data_type is
		variable result : data_type;
	begin
		result := (others => instr(31));
		result(11) := instr(7);
		result(10 downto 5) := instr(30 downto 25);
		result(4 downto 1) := instr(11 downto 8);
		result(0) := '0';
		return result;
	end function;

	-- U-format immediate: the upper 20 instruction bits are taken verbatim,
	-- the low 12 bits are zero.
	function calculate_immediate_U_fmt(instr : instr_type) return data_type is
		variable result : data_type;
	begin
		result := (others => '0');
		result(31 downto 12) := instr(31 downto 12);
		return result;
	end function;

	-- J-format (JAL offset) immediate: imm[20] = instr(31) (sign),
	-- imm[19:12] = instr(19 downto 12), imm[11] = instr(20),
	-- imm[10:1] = instr(30 downto 21); bit 0 is always '0'.
	function calculate_immediate_J_fmt(instr : instr_type) return data_type is
		variable result : data_type;
	begin
		result := (others => instr(31));
		result(19 downto 12) := instr(19 downto 12);
		result(11) := instr(20);
		result(10 downto 1) := instr(30 downto 21);
		result(0) := '0';
		return result;
	end function;

	-- ALU operation for an OP-IMM instruction, chosen by funct3. For the shift
	-- right group (FUNCT_SRI), instr(30) = '1' selects SRA and any other value
	-- selects SRL. A funct3 that matches no constant (only possible with
	-- non-'0'/'1' bits, since all eight binary codes are listed) gives ALU_NOP.
	function calculate_imm_aluop(instr : instr_type) return alu_op_type is
		constant funct3 : std_logic_vector(i_funct_range_t) := instr(funct3_I_fmt_range_t);
		variable result : alu_op_type;
	begin
		if funct3 = FUNCT_ADDI then
			result := ALU_ADD;
		elsif funct3 = FUNCT_SLLI then
			result := ALU_SLL;
		elsif funct3 = FUNCT_SLTI then
			result := ALU_SLT;
		elsif funct3 = FUNCT_SLTIU then
			result := ALU_SLTU;
		elsif funct3 = FUNCT_XORI then
			result := ALU_XOR;
		elsif funct3 = FUNCT_SRI then
			if instr(30) = '1' then
				result := ALU_SRA;
			else
				result := ALU_SRL;
			end if;
		elsif funct3 = FUNCT_ORI then
			result := ALU_OR;
		elsif funct3 = FUNCT_ANDI then
			result := ALU_AND;
		else
			result := ALU_NOP;
		end if;
		return result;
	end function;

	-- ALU operation for an R-format (OP) instruction, looked up from the
	-- combined funct3 & funct7 selector. Any combination without a matching
	-- FUNCT_* constant returns ALU_NOP; the decoder relies on that value to
	-- flag illegal OP encodings.
	function calculate_aluop(instr : instr_type) return alu_op_type is
		constant selector : std_logic_vector(r_funct_range_t) :=
			instr(funct3_R_fmt_range_t) & instr(funct7_R_fmt_range_t);
		variable result : alu_op_type := ALU_NOP;
	begin
		case selector is
			when FUNCT_ADD  => result := ALU_ADD;
			when FUNCT_SUB  => result := ALU_SUB;
			when FUNCT_SLL  => result := ALU_SLL;
			when FUNCT_SLT  => result := ALU_SLT;
			when FUNCT_SLTU => result := ALU_SLTU;
			when FUNCT_XOR  => result := ALU_XOR;
			when FUNCT_SRL  => result := ALU_SRL;
			when FUNCT_SRA  => result := ALU_SRA;
			when FUNCT_OR   => result := ALU_OR;
			when FUNCT_AND  => result := ALU_AND;
			-- undefined selector: keep the ALU_NOP initial value
			when others     => null;
		end case;
		return result;
	end function;

	-- ALU operation used to evaluate a branch condition: BEQ/BNE subtract,
	-- BLT/BGE use SLT and BLTU/BGEU use SLTU. Which polarity of the result
	-- takes the branch is chosen separately by calculate_branch_branchop.
	-- Undefined funct3 values give ALU_NOP.
	function calculate_branch_aluop(instr : instr_type) return alu_op_type is
		variable result : alu_op_type;
	begin
		case instr(funct3_B_fmt_range_t) is
			when FUNCT_BEQ  | FUNCT_BNE  => result := ALU_SUB;
			when FUNCT_BLT  | FUNCT_BGE  => result := ALU_SLT;
			when FUNCT_BLTU | FUNCT_BGEU => result := ALU_SLTU;
			when others                  => result := ALU_NOP;
		end case;
		return result;
	end function;

	-- Branch kind for a conditional branch: BR_CND for BEQ/BGE/BGEU and
	-- BR_CNDI for their complements BNE/BLT/BLTU; undefined funct3 gives
	-- BR_NOP. NOTE(review): presumably BR_CNDI means "taken on the inverted
	-- ALU condition"; confirm against op_pkg and the memory stage.
	function calculate_branch_branchop(instr : instr_type) return branch_type is
		variable result : branch_type;
	begin
		case instr(funct3_B_fmt_range_t) is
			when FUNCT_BEQ | FUNCT_BGE | FUNCT_BGEU => result := BR_CND;
			when FUNCT_BNE | FUNCT_BLT | FUNCT_BLTU => result := BR_CNDI;
			when others                             => result := BR_NOP;
		end case;
		return result;
	end function;

	-- '1' when the BRANCH funct3 is one of BEQ, BNE, BLT, BGE, BLTU, BGEU;
	-- '0' otherwise (e.g. "010", "011", or non-binary bits).
	function is_branch_funct3_valid(instr : instr_type) return std_logic is
		variable result : std_logic;
	begin
		case instr(funct3_B_fmt_range_t) is
			when FUNCT_BEQ | FUNCT_BNE | FUNCT_BLT | FUNCT_BGE | FUNCT_BLTU | FUNCT_BGEU =>
				result := '1';
			when others =>
				result := '0';
		end case;
		return result;
	end function;

	-- '1' when the LOAD funct3 is one of LB, LH, LW, LBU, LHU;
	-- '0' otherwise ("011", "110", "111", or non-binary bits).
	function is_load_funct3_valid(instr : instr_type) return std_logic is
		variable result : std_logic;
	begin
		case instr(funct3_I_fmt_range_t) is
			when FUNCT_LB | FUNCT_LH | FUNCT_LW | FUNCT_LBU | FUNCT_LHU =>
				result := '1';
			when others =>
				result := '0';
		end case;
		return result;
	end function;

	-- '1' when the STORE funct3 is one of SB, SH, SW; '0' otherwise.
	function is_store_funct3_valid(instr : instr_type) return std_logic is
		variable result : std_logic;
	begin
		case instr(funct3_S_fmt_range_t) is
			when FUNCT_SB | FUNCT_SH | FUNCT_SW =>
				result := '1';
			when others =>
				result := '0';
		end case;
		return result;
	end function;

	-- Access width for a STORE: MEM_B for SB, MEM_H for SH, MEM_W for
	-- everything else. The MEM_W fallback also absorbs undefined funct3
	-- values; the decoder reports those via is_store_funct3_valid.
	function calculate_store_memop(instr : instr_type) return memtype_type is
		constant funct3 : std_logic_vector(s_funct_range_t) := instr(funct3_S_fmt_range_t);
		variable result : memtype_type;
	begin
		if funct3 = FUNCT_SB then
			result := MEM_B;
		elsif funct3 = FUNCT_SH then
			result := MEM_H;
		else
			result := MEM_W;
		end if;
		return result;
	end function;

	-- Access width/signedness for a LOAD: MEM_B, MEM_H, MEM_BU, MEM_HU for
	-- LB, LH, LBU, LHU respectively and MEM_W for everything else. The MEM_W
	-- fallback also absorbs undefined funct3 values; the decoder reports
	-- those via is_load_funct3_valid.
	function calculate_load_memop(instr : instr_type) return memtype_type is
		constant funct3 : std_logic_vector(i_funct_range_t) := instr(funct3_I_fmt_range_t);
		variable result : memtype_type;
	begin
		if funct3 = FUNCT_LB then
			result := MEM_B;
		elsif funct3 = FUNCT_LH then
			result := MEM_H;
		elsif funct3 = FUNCT_LBU then
			result := MEM_BU;
		elsif funct3 = FUNCT_LHU then
			result := MEM_HU;
		else
			result := MEM_W;
		end if;
		return result;
	end function;

begin
	-- Register file. Its read addresses come from the incoming instruction
	-- word (instr) rather than the latched internal.instr.
	-- NOTE(review): this only lines up with internal.instr if work.regfile
	-- reads synchronously; confirm against the regfile implementation.
	regfile_inst : entity work.regfile
	port map (
		clk      => clk,
		res_n    => res_n,
		stall    => stall,

		-- write port
		wraddr   => reg.wraddr,
		wrdata   => reg.wrdata,
		regwrite => reg.regwrite,

		-- read ports
		rdaddr1  => reg.rdaddr1,
		rddata1  => reg.rddata1,
		rdaddr2  => reg.rdaddr2,
		rddata2  => reg.rddata2
	);

	-- write port: passed through unchanged from the writeback stage
	reg.wraddr   <= reg_write.reg;
	reg.wrdata   <= reg_write.data;
	reg.regwrite <= reg_write.write;

	-- read addresses: rs1/rs2 fields of the instruction arriving from fetch
	reg.rdaddr1 <= instr(rs1_range_t);
	reg.rdaddr2 <= instr(rs2_range_t);

	-- Pipeline register: asynchronous active-low reset to INITIAL_INTERNAL,
	-- otherwise take over PC and instruction from fetch on every rising clock
	-- edge at which stall is '0'.
	sync : process(clk, res_n)
	begin
		if res_n = '0' then
			internal <= INITIAL_INTERNAL;
		elsif rising_edge(clk) then
			if stall = '0' then
				internal <= (pc => pc_in, instr => instr);
			end if;
		end if;
	end process;

	-- Combinational decode of the latched instruction (internal.instr).
	--
	-- Every output first receives a NOP-style default; when flush = '0' the
	-- opcode selects which execute/memory/writeback controls are overridden.
	-- exc_dec is raised for unrecognized opcodes and for funct3/funct7
	-- encodings that the respective instruction class does not define.
	async : process(all)
	begin
		-- default assignments
		pc_out <= internal.pc;
		exc_dec <= '0';

		-- exec_op
		exec_op.aluop <= ALU_NOP;
		exec_op.alusrc1 <= '0';
		exec_op.alusrc2 <= '0';
		exec_op.alusrc3 <= '0';
		exec_op.rs1 <= internal.instr(rs1_range_t);
		exec_op.rs2 <= internal.instr(rs2_range_t);
		exec_op.readdata1 <= reg.rddata1;
		exec_op.readdata2 <= reg.rddata2;
		exec_op.imm <= (others => '0');

		-- mem_op
		mem_op.branch <= BR_NOP;
		mem_op.mem <= MEMU_NOP;

		-- wb_op (rd is always forwarded; write stays '0' unless set below)
		wb_op <= WB_NOP;
		wb_op.rd <= internal.instr(rd_range_t);

		if flush = '0' then
			-- determine instr by opcode
			case internal.instr(opcode_range_t) is
				when OPC_LUI =>
					exec_op.alusrc2 <= '1';
					exec_op.imm <= calculate_immediate_U_fmt(internal.instr);
					wb_op.write <= '1';
					wb_op.src <= WBS_ALU;

				when OPC_AUIPC =>
					exec_op.aluop <= ALU_ADD;
					exec_op.alusrc1 <= '1';
					exec_op.alusrc2 <= '1';
					exec_op.imm <= calculate_immediate_U_fmt(internal.instr);
					wb_op.write <= '1';
					wb_op.src <= WBS_ALU;

				when OPC_JAL =>
					exec_op.alusrc3 <= '1';
					exec_op.imm <= calculate_immediate_J_fmt(internal.instr);
					mem_op.branch <= BR_BR;
					wb_op.write <= '1';
					wb_op.src <= WBS_OPC;

				when OPC_JALR =>
					exec_op.aluop <= ALU_NOP;
					exec_op.alusrc1 <= '1';
					exec_op.alusrc2 <= '1';
					exec_op.alusrc3 <= '1';
					exec_op.imm <= calculate_immediate_I_fmt(internal.instr);
					mem_op.branch <= BR_BR;
					wb_op.write <= '1';
					wb_op.src <= WBS_OPC;
					-- JALR only defines funct3 = "000"
					if internal.instr(funct3_I_fmt_range_t) /= "000" then
						exc_dec <= '1';
					end if;

				when OPC_BRANCH =>
					exec_op.aluop <= calculate_branch_aluop(internal.instr);
					exec_op.alusrc3 <= '1';
					exec_op.imm <= calculate_immediate_B_fmt(internal.instr);
					mem_op.branch <= calculate_branch_branchop(internal.instr);
					exc_dec <= not(is_branch_funct3_valid(internal.instr));

				when OPC_LOAD =>
					exec_op.aluop <= ALU_ADD;
					exec_op.alusrc2 <= '1';
					exec_op.imm <= calculate_immediate_I_fmt(internal.instr);
					mem_op.mem.memread <= '1';
					mem_op.mem.memwrite <= '0';
					mem_op.mem.memtype <= calculate_load_memop(internal.instr);
					wb_op.write <= '1';
					wb_op.src <= WBS_MEM;
					exc_dec <= not(is_load_funct3_valid(internal.instr));

				when OPC_STORE =>
					exec_op.aluop <= ALU_ADD;
					exec_op.alusrc2 <= '1';
					exec_op.imm <= calculate_immediate_S_fmt(internal.instr);
					mem_op.mem.memread <= '0';
					mem_op.mem.memwrite <= '1';
					mem_op.mem.memtype <= calculate_store_memop(internal.instr);
					exc_dec <= not(is_store_funct3_valid(internal.instr));

				when OPC_OP_IMM =>
					exec_op.aluop <= calculate_imm_aluop(internal.instr);
					exec_op.alusrc2 <= '1';
					exec_op.imm <= calculate_immediate_I_fmt(internal.instr);
					wb_op.write <= '1';
					wb_op.src <= WBS_ALU;
					-- Shift-immediates reuse imm[11:5] (instr(31 downto 25), the
					-- same bits as the R-format funct7) as a funct7-like field:
					-- SLLI and SRLI require "0000000", SRAI requires "0100000".
					-- All other values are reserved in RV32I, so flag them just
					-- like the undefined funct7 cases of OPC_OP.
					if internal.instr(funct3_I_fmt_range_t) = FUNCT_SLLI then
						if internal.instr(funct7_R_fmt_range_t) /= "0000000" then
							exc_dec <= '1';
						end if;
					elsif internal.instr(funct3_I_fmt_range_t) = FUNCT_SRI then
						if internal.instr(funct7_R_fmt_range_t) /= "0000000" and
						   internal.instr(funct7_R_fmt_range_t) /= "0100000" then
							exc_dec <= '1';
						end if;
					end if;

				when OPC_OP =>
					exec_op.aluop <= calculate_aluop(internal.instr);
					-- calculate_aluop maps undefined funct3/funct7 pairs to ALU_NOP
					if calculate_aluop(internal.instr) = ALU_NOP then
						exc_dec <= '1';
					end if;
					wb_op.write <= '1';
					wb_op.src <= WBS_ALU;

				-- nop opcode: listed explicitly so it does not fall into the
				-- exception case; only funct3 = "000" is accepted
				when OPC_NOP =>
					if internal.instr(funct3_I_fmt_range_t) /= "000" then
						exc_dec <= '1';
					end if;

				-- unrecognized opcodes throw an exception
				when others =>
					exc_dec <= '1';

			end case;
		end if;
	end process;
end architecture;
